#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/9/5 21:21
@File: human_matting.py
@IDE: pycharm
@Description:
    人像抠图
"""
import numpy as np
from PIL import Image
import onnxruntime
from .tensor2numpy import NNormalize, NTo_Tensor, NUnsqueeze
from .context import Context
import cv2
import os
import logging
from time import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Mapping from model name to the on-disk path of its weights file, resolved
# relative to this module's directory ("weights/" subfolder).
WEIGHTS = {
    "hivision_modnet": os.path.join(
        os.path.dirname(__file__), "weights", "hivision_modnet.onnx"
    ),
    "modnet_photographic_portrait_matting": os.path.join(
        os.path.dirname(__file__),
        "weights",
        "modnet_photographic_portrait_matting.onnx",
    ),
    "mnn_hivision_modnet": os.path.join(
        os.path.dirname(__file__),
        "weights",
        "mnn_hivision_modnet.mnn",
    ),
    "rmbg-1.4": os.path.join(os.path.dirname(__file__), "weights", "rmbg-1.4.onnx"),
    "birefnet-v1-lite": os.path.join(
        os.path.dirname(__file__), "weights", "birefnet-v1-lite.onnx"
    ),
}

# Device reported by the installed onnxruntime build (e.g. "GPU" or "CPU").
ONNX_DEVICE = onnxruntime.get_device()
# Execution provider preference derived from the detected device.
ONNX_PROVIDER = (
    "CUDAExecutionProvider" if ONNX_DEVICE == "GPU" else "CPUExecutionProvider"
)

# Lazily-created InferenceSession caches, one per model. They are populated on
# first use and released after each call unless RUN_MODE=beast keeps them alive
# (see the matting functions below).
HIVISION_MODNET_SESS = None
MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = None
RMBG_SESS = None
BIREFNET_V1_LITE_SESS = None


def load_onnx_model(checkpoint_path, set_cpu=False):
    """Safely load an ONNX model into an onnxruntime InferenceSession.

    :param checkpoint_path: path to the .onnx weights file
    :param set_cpu: force CPUExecutionProvider regardless of detected device
    :return: an onnxruntime.InferenceSession
    :raises FileNotFoundError: if the weights file does not exist
    :raises PermissionError: if the weights file is not readable
    :raises ValueError: if the weights file is empty
    """
    # Validate the weights file up front so we fail with a clear message
    # instead of an opaque onnxruntime error.
    if not os.path.exists(checkpoint_path):
        logger.error(f"Model file not found: {checkpoint_path}")
        raise FileNotFoundError(f"Model file not found: {checkpoint_path}")

    if not os.access(checkpoint_path, os.R_OK):
        logger.error(f"Model file is not readable: {checkpoint_path}")
        raise PermissionError(f"Model file is not readable: {checkpoint_path}")

    if os.path.getsize(checkpoint_path) == 0:
        logger.error(f"Model file is empty: {checkpoint_path}")
        raise ValueError(f"Model file is empty: {checkpoint_path}")

    if set_cpu:
        sess = onnxruntime.InferenceSession(
            checkpoint_path, providers=["CPUExecutionProvider"]
        )
    else:
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if ONNX_PROVIDER == "CUDAExecutionProvider"
            else ["CPUExecutionProvider"]
        )
        try:
            sess = onnxruntime.InferenceSession(checkpoint_path, providers=providers)
        except Exception as e:
            # BUG FIX: the original compared ONNX_DEVICE (which holds "GPU" or
            # "CPU") against "CUDAExecutionProvider", so this CPU fallback was
            # unreachable and any GPU load failure always propagated.
            if ONNX_DEVICE == "GPU":
                logger.warning(f"Failed to load model with CUDAExecutionProvider: {e}")
                logger.warning("Falling back to CPUExecutionProvider")
                # Retry on CPU before giving up.
                try:
                    sess = onnxruntime.InferenceSession(
                        checkpoint_path, providers=["CPUExecutionProvider"]
                    )
                except Exception as fallback_e:
                    logger.error(
                        f"Failed to load model with CPUExecutionProvider: {fallback_e}"
                    )
                    raise fallback_e
            else:
                logger.error(f"Failed to load model: {e}")
                raise e  # CPU execution failed: re-raise

    logger.info(f"Successfully loaded model: {checkpoint_path}")
    return sess


def extract_human(ctx: Context):
    """Run portrait matting with the hivision_modnet model.

    Stores the hole-fixed BGRA result in ctx.processing_image and keeps a
    copy in ctx.matting_image.

    :param ctx: processing context
    :raises ValueError: if the matting model returns None
    """
    result = get_modnet_matting(ctx.processing_image, WEIGHTS["hivision_modnet"])
    if result is None:
        raise ValueError(
            "Human matting failed - model returned None. Please check if the model file exists and is valid."
        )
    # Patch holes the matting model may leave inside the subject.
    fixed = hollow_out_fix(result)
    ctx.processing_image = fixed
    ctx.matting_image = fixed.copy()


def extract_human_modnet_photographic_portrait_matting(ctx: Context):
    """Run portrait matting with the photographic-portrait MODNet model.

    Stores the BGRA result in ctx.processing_image and keeps a copy in
    ctx.matting_image.

    :param ctx: processing context
    :raises ValueError: if the matting model returns None
    """
    result = get_modnet_matting_photographic_portrait_matting(
        ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
    )
    if result is None:
        raise ValueError(
            "Human matting failed - model returned None. Please check if the model file exists and is valid."
        )
    ctx.processing_image = result
    ctx.matting_image = result.copy()


def extract_human_mnn_modnet(ctx: Context):
    """Run portrait matting with the MNN build of hivision_modnet.

    :param ctx: processing context
    :raises ValueError: if the matting model returns None
    """
    result = get_mnn_modnet_matting(
        ctx.processing_image, WEIGHTS["mnn_hivision_modnet"]
    )
    if result is None:
        raise ValueError(
            "Human matting failed - model returned None. Please check if the model file exists and is valid."
        )
    # Patch holes the matting model may leave inside the subject.
    fixed = hollow_out_fix(result)
    ctx.processing_image = fixed
    ctx.matting_image = fixed.copy()


def extract_human_rmbg(ctx: Context):
    """Run portrait matting with the RMBG-1.4 model.

    :param ctx: processing context
    :raises ValueError: if the matting model returns None
    """
    result = get_rmbg_matting(ctx.processing_image, WEIGHTS["rmbg-1.4"])
    if result is None:
        raise ValueError(
            "Human matting failed - model returned None. Please check if the model file exists and is valid."
        )
    ctx.processing_image = result
    ctx.matting_image = result.copy()


# def extract_human_birefnet_portrait(ctx: Context):
#     matting_image = get_birefnet_portrait_matting(
#         ctx.processing_image, WEIGHTS["birefnet-portrait"]
#     )
#     ctx.processing_image = matting_image
#     ctx.matting_image = ctx.processing_image.copy()


def extract_human_birefnet_lite(ctx: Context):
    """Run portrait matting with the BiRefNet-v1-lite model.

    :param ctx: processing context
    :raises ValueError: if the matting model returns None
    """
    result = get_birefnet_portrait_matting(
        ctx.processing_image, WEIGHTS["birefnet-v1-lite"]
    )
    if result is None:
        raise ValueError(
            "Human matting failed - model returned None. Please check if the model file exists and is valid."
        )
    ctx.processing_image = result
    ctx.matting_image = result.copy()


def hollow_out_fix(src: np.ndarray) -> np.ndarray:
    """Fill accidental holes inside the matted foreground.

    Complements the matting model's limited precision: keeps the largest
    connected foreground region's outline, flood-fills the exterior from a
    corner, and adds every pixel the fill could not reach (i.e. interior
    holes) back into the alpha channel.

    :param src: BGRA image produced by the matting model
    :return: BGRA image with interior alpha holes filled
    """
    b, g, r, a = cv2.split(src)
    src_bgr = cv2.merge((b, g, r))
    # Pad the alpha by 10px on every side so a subject touching the image
    # border still yields a closed contour for the flood fill below.
    pad = np.zeros((10, a.shape[1]), np.uint8)
    a = np.vstack((pad, a, pad))
    pad = np.zeros((a.shape[0], 10), np.uint8)
    a = np.hstack((pad, a, pad))
    _, a_threshold = cv2.threshold(a, 127, 255, 0)
    a_erode = cv2.erode(
        a_threshold,
        kernel=cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)),
        iterations=3,
    )
    contours, _ = cv2.findContours(
        a_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    # BUG FIX: the original indexed contours[0] unconditionally, raising
    # IndexError when the alpha is all-background; return the input unchanged.
    if not contours:
        return src
    largest = max(contours, key=cv2.contourArea)
    # BUG FIX: the original passed contours[0] (an array of points) as the
    # contour *list* argument, so each point was drawn as its own "contour";
    # wrap the largest contour in a list so it is drawn as one closed outline.
    a_contour = cv2.drawContours(np.zeros(a.shape, np.uint8), [largest], -1, 255, 2)
    h, w = a.shape[:2]
    # floodFill requires a uint8 single-channel mask 2px larger in each dim.
    mask = np.zeros([h + 2, w + 2], np.uint8)
    cv2.floodFill(a_contour, mask=mask, seedPoint=(0, 0), newVal=255)
    # Pixels the fill could not reach are interior holes; add them to alpha,
    # then crop away the 10px padding.
    a = cv2.add(a, 255 - a_contour)
    return cv2.merge((src_bgr, a[10:-10, 10:-10]))


def image2bgr(input_image):
    """Coerce an image array to a 3-channel layout.

    2-D and single-channel inputs are replicated across three channels;
    4-channel inputs are stripped of their last (alpha) channel; 3-channel
    inputs are returned untouched.

    :param input_image: numpy image array
    :return: numpy array with shape (H, W, 3)
    """
    img = input_image
    if img.ndim == 2:
        img = img[:, :, None]
    channels = img.shape[2]
    if channels == 1:
        return np.repeat(img, 3, axis=2)
    if channels == 4:
        return img[:, :, :3]
    return img


def read_modnet_image(input_image, ref_size=512):
    """Preprocess an image for MODNet inference.

    Converts to 3 channels, resizes to (ref_size, ref_size), normalizes to
    [-1, 1], and lays the result out as an NCHW tensor.

    :param input_image: image as a numpy array
    :param ref_size: side length of the square model input
    :return: (tensor, original width, original height)
    """
    pil_image = Image.fromarray(np.uint8(input_image))
    width, length = pil_image.size
    arr = image2bgr(np.asarray(pil_image))
    arr = cv2.resize(arr, (ref_size, ref_size), interpolation=cv2.INTER_AREA)
    arr = NNormalize(arr, mean=np.array([0.5, 0.5, 0.5]), std=np.array([0.5, 0.5, 0.5]))
    tensor = NUnsqueeze(NTo_Tensor(arr))
    return tensor, width, length


def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
    """Matte a portrait with the hivision_modnet ONNX model.

    :param input_image: BGR image as a numpy array
    :param checkpoint_path: path to the ONNX weights
    :param ref_size: square input size fed to the model
    :return: BGRA image with the predicted alpha, or None on failure
    """
    global HIVISION_MODNET_SESS

    # Validate the input before touching the model.
    if input_image is None:
        logger.error("Input image is None")
        return None
    if not isinstance(input_image, np.ndarray):
        logger.error("Input image is not a numpy array")
        return None
    if input_image.size == 0:
        logger.error("Input image is empty")
        return None

    # Lazily (re)create the session; it is dropped after each call unless
    # RUN_MODE=beast keeps it cached (see below).
    if HIVISION_MODNET_SESS is None:
        try:
            HIVISION_MODNET_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
        except Exception as e:
            logger.error(f"Failed to load modnet model: {e}")
            return None

    try:
        sess = HIVISION_MODNET_SESS
        in_name = sess.get_inputs()[0].name
        out_name = sess.get_outputs()[0].name

        tensor, width, length = read_modnet_image(
            input_image=input_image, ref_size=ref_size
        )

        matte = sess.run([out_name], {in_name: tensor})
        alpha = np.squeeze((matte[0] * 255).astype("uint8"))
        # Scale the predicted alpha back to the source resolution and attach
        # it as a fourth channel.
        mask = cv2.resize(alpha, (width, length), interpolation=cv2.INTER_AREA)
        b, g, r = cv2.split(np.uint8(input_image))
        output_image = cv2.merge((b, g, r, mask))

        # Release the session unless beast mode caches models between calls.
        if os.getenv("RUN_MODE") != "beast":
            HIVISION_MODNET_SESS = None

        return output_image
    except Exception as e:
        logger.error(f"Error during modnet matting inference: {e}")
        return None


def get_modnet_matting_photographic_portrait_matting(
    input_image, checkpoint_path, ref_size=512
):
    """Matte a portrait with the photographic-portrait MODNet ONNX model.

    :param input_image: BGR image as a numpy array
    :param checkpoint_path: path to the ONNX weights
    :param ref_size: square input size fed to the model
    :return: BGRA image with the predicted alpha, or None on failure
    """
    global MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS

    # Validate the input before touching the model.
    if input_image is None:
        logger.error("Input image is None")
        return None
    if not isinstance(input_image, np.ndarray):
        logger.error("Input image is not a numpy array")
        return None
    if input_image.size == 0:
        logger.error("Input image is empty")
        return None

    # Lazily (re)create the session; it is dropped after each call unless
    # RUN_MODE=beast keeps it cached (see below).
    if MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS is None:
        try:
            MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = load_onnx_model(
                checkpoint_path, set_cpu=True
            )
        except Exception as e:
            logger.error(f"Failed to load modnet photographic portrait model: {e}")
            return None

    try:
        sess = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS
        in_name = sess.get_inputs()[0].name
        out_name = sess.get_outputs()[0].name

        tensor, width, length = read_modnet_image(
            input_image=input_image, ref_size=ref_size
        )

        matte = sess.run([out_name], {in_name: tensor})
        alpha = np.squeeze((matte[0] * 255).astype("uint8"))
        # Scale the predicted alpha back to the source resolution and attach
        # it as a fourth channel.
        mask = cv2.resize(alpha, (width, length), interpolation=cv2.INTER_AREA)
        b, g, r = cv2.split(np.uint8(input_image))
        output_image = cv2.merge((b, g, r, mask))

        # Release the session unless beast mode caches models between calls.
        if os.getenv("RUN_MODE") != "beast":
            MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = None

        return output_image
    except Exception as e:
        logger.error(
            f"Error during modnet photographic portrait matting inference: {e}"
        )
        return None


def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
    """Matte a portrait with the RMBG-1.4 ONNX model.

    :param input_image: image as a numpy array
    :param checkpoint_path: path to the ONNX weights
    :param ref_size: square input size fed to the model
    :return: RGBA numpy array with the predicted alpha, or None on failure
    """
    global RMBG_SESS

    # Validate the input before touching the model.
    if input_image is None:
        logger.error("Input image is None")
        return None

    if not isinstance(input_image, np.ndarray):
        logger.error("Input image is not a numpy array")
        return None

    if input_image.size == 0:
        logger.error("Input image is empty")
        return None

    def resize_rmbg_image(image):
        # Force RGB and the model's fixed square input size.
        image = image.convert("RGB")
        model_input_size = (ref_size, ref_size)
        image = image.resize(model_input_size, Image.BILINEAR)
        return image

    # Lazily (re)create the session; it is dropped after each call unless
    # RUN_MODE=beast keeps it cached (see below).
    if RMBG_SESS is None:
        try:
            RMBG_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
        except Exception as e:
            logger.error(f"Failed to load RMBG model: {e}")
            return None

    try:
        orig_image = Image.fromarray(input_image)
        image = resize_rmbg_image(orig_image)
        im_np = np.array(image).astype(np.float32)
        im_np = im_np.transpose(2, 0, 1)  # Change to CxHxW format
        im_np = np.expand_dims(im_np, axis=0)  # Add batch dimension
        im_np = im_np / 255.0  # Normalize to [0, 1]
        im_np = (im_np - 0.5) / 0.5  # Normalize to [-1, 1]

        # Inference
        result = RMBG_SESS.run(None, {RMBG_SESS.get_inputs()[0].name: im_np})[0]

        # Post process: min-max normalize the prediction to [0, 1].
        result = np.squeeze(result)
        ma = np.max(result)
        mi = np.min(result)
        # BUG FIX: the original divided by (ma - mi) unconditionally, which
        # produced a NaN-filled mask when the model output is constant.
        if ma > mi:
            result = (result - mi) / (ma - mi)
        else:
            result = np.zeros_like(result)

        # Convert to a single-channel (L mode) PIL mask.
        im_array = (result * 255).astype(np.uint8)
        pil_im = Image.fromarray(im_array, mode="L")

        # Resize the mask to match the original image size.
        pil_im = pil_im.resize(orig_image.size, Image.BILINEAR)

        # Composite the original image over transparency using the mask.
        new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
        new_im.paste(orig_image, mask=pil_im)

        # Release the session unless beast mode caches models between calls.
        if os.getenv("RUN_MODE") != "beast":
            RMBG_SESS = None

        return np.array(new_im)
    except Exception as e:
        logger.error(f"Error during RMBG matting inference: {e}")
        return None


def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512):
    """Matte a portrait with the MNN build of hivision_modnet.

    :param input_image: BGR image as a numpy array
    :param checkpoint_path: path to the .mnn weights
    :param ref_size: square input size fed to the model
    :return: BGRA image with the predicted alpha, or None on failure
    :raises ImportError: if the MNN package is not installed
    """
    # Validate the input before touching the model.
    if input_image is None:
        logger.error("Input image is None")
        return None

    if not isinstance(input_image, np.ndarray):
        logger.error("Input image is not a numpy array")
        return None

    if input_image.size == 0:
        logger.error("Input image is empty")
        return None

    if not os.path.exists(checkpoint_path):
        logger.error(f"Model file not found: {checkpoint_path}")
        return None

    try:
        import MNN.expr as expr
        import MNN.nn as nn
    except ImportError as e:
        logger.error("The MNN module is not installed")
        raise ImportError(
            "The MNN module is not installed or there was an import error. Please ensure that the MNN library is installed by using the command 'pip install mnn'."
        ) from e

    try:
        config = {}
        config["precision"] = "low"  # use fp16 inference when hardware (armv8.2) supports it
        config["backend"] = 0  # CPU
        config["numThread"] = 4  # thread count
        # BUG FIX: the original hardcoded ref_size=512 here, silently
        # ignoring the caller-supplied ref_size parameter.
        im, width, length = read_modnet_image(input_image, ref_size=ref_size)
        rt = nn.create_runtime_manager((config,))
        net = nn.load_module_from_file(
            checkpoint_path, ["input1"], ["output1"], runtime_manager=rt
        )
        input_var = expr.convert(im, expr.NCHW)
        output_var = net.forward(input_var)
        matte = expr.convert(output_var, expr.NCHW)
        matte = matte.read()  # convert MNN var to numpy
        matte = (matte * 255).astype("uint8")
        matte = np.squeeze(matte)
        # Scale the predicted alpha back to the source resolution and attach
        # it as a fourth channel.
        mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
        b, g, r = cv2.split(np.uint8(input_image))

        output_image = cv2.merge((b, g, r, mask))

        return output_image
    except Exception as e:
        logger.error(f"Error during MNN modnet matting inference: {e}")
        return None


def get_birefnet_portrait_matting(input_image, checkpoint_path, ref_size=512):
    """Matte a portrait with the BiRefNet-v1-lite ONNX model.

    NOTE(review): ref_size is currently unused — the model input is fixed at
    1024x1024 below; the parameter is kept for signature compatibility.

    :param input_image: image as a numpy array
    :param checkpoint_path: path to the ONNX weights
    :param ref_size: unused (model input is fixed at 1024x1024)
    :return: RGBA numpy array with the predicted alpha, or None on failure
    """
    global BIREFNET_V1_LITE_SESS

    # Validate the input before touching the model.
    if input_image is None:
        logger.error("Input image is None")
        return None
    if not isinstance(input_image, np.ndarray):
        logger.error("Input image is not a numpy array")
        return None
    if input_image.size == 0:
        logger.error("Input image is empty")
        return None

    def transform_image(image):
        # Resize to the fixed 1024x1024 model input, normalize with ImageNet
        # statistics, and lay out as a float32 NCHW batch of one.
        resized = image.resize((1024, 1024))
        arr = np.array(resized, dtype=np.float32) / 255.0
        arr = (arr - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        arr = np.transpose(arr, (2, 0, 1))
        arr = np.expand_dims(arr, axis=0)
        return arr.astype(np.float32)

    try:
        orig_image = Image.fromarray(input_image)
        input_images = transform_image(orig_image)

        # Lazily load the session; it is dropped after the call unless
        # RUN_MODE=beast keeps it cached (see below).
        if BIREFNET_V1_LITE_SESS is None:
            if ONNX_DEVICE == "GPU":
                logger.info("onnxruntime-gpu已安装，尝试使用CUDA加载模型")
                try:
                    import torch
                except ImportError:
                    logger.warning(
                        "torch未安装，尝试直接使用onnxruntime-gpu加载模型，这需要配置好CUDA和cuDNN"
                    )
                BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path)
            else:
                BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path, set_cpu=True)

        input_name = BIREFNET_V1_LITE_SESS.get_inputs()[0].name

        time_st = time()
        # Take the last output head and map its logits to [0, 1] via sigmoid.
        pred_onnx = BIREFNET_V1_LITE_SESS.run(None, {input_name: input_images})[-1]
        pred_onnx = np.squeeze(pred_onnx)
        result = 1 / (1 + np.exp(-pred_onnx))
        logger.info(f"Inference time: {time() - time_st:.4f} seconds")

        # Build a single-channel (L mode) mask and scale it back to the
        # original image size.
        mask = Image.fromarray((result * 255).astype(np.uint8), mode="L")
        mask = mask.resize(orig_image.size, Image.BILINEAR)

        # Composite the original image over transparency using the mask.
        new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
        new_im.paste(orig_image, mask=mask)

        # Release the session unless beast mode caches models between calls.
        if os.getenv("RUN_MODE") != "beast":
            BIREFNET_V1_LITE_SESS = None

        return np.array(new_im)
    except Exception as e:
        logger.error(f"Error during BiRefNet matting inference: {e}")
        return None
